{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "csharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "#r \"nuget: TorchSharp-cpu\"\n",
    "\n",
    "using TorchSharp;\n",
    "using static TorchSharp.torch;\n",
    "using static TorchSharp.TensorExtensionMethods;\n",
    "using static TorchSharp.torch.distributions;"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training with a Learning Rate Scheduler\n",
    "\n",
    "In Tutorial 6, we saw that the optimizers take an argument called the 'learning rate,' but we didn't spend much time on it except to say that it can have a great impact on how quickly training converges toward a solution. In fact, you can choose the learning rate (LR) so poorly that training doesn't converge at all.\n",
    "\n",
    "If the LR is too small, training will go very slowly, wasting compute resources. If it is too large, training could result in numeric overflow, or NaNs. Either way, you're in trouble."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To further complicate matters, it turns out that the learning rate shouldn't necessarily be constant. Training can go much better if the learning rate starts out relatively large and gets smaller as you get closer to the end.\n",
    "\n",
    "There's a solution for this, called a Learning Rate Scheduler. An LRS instance has access to the internal state of the optimizer, and can modify the LR as it goes along. Some schedulers modify other optimizer state, too, such as the momentum (for optimizers that use momentum).\n",
    "\n",
    "There are several algorithms for scheduling, and TorchSharp implements a number of them."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before demonstrating, let's have a model and a baseline training loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "csharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "private class Trivial : nn.Module<Tensor,Tensor>\n",
    "{\n",
    "    public Trivial()\n",
    "        : base(nameof(Trivial))\n",
    "    {\n",
    "        RegisterComponents();\n",
    "    }\n",
    "\n",
    "    public override Tensor forward(Tensor input)\n",
    "    {\n",
    "        using var x = lin1.forward(input);\n",
    "        using var y = nn.functional.relu(x);\n",
    "        return lin2.forward(y);\n",
    "    }\n",
    "\n",
    "    private nn.Module<Tensor,Tensor> lin1 = nn.Linear(1000, 100);\n",
    "    private nn.Module<Tensor,Tensor> lin2 = nn.Linear(100, 10);\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate how to correctly use an LR scheduler, our training data needs to look more like real training data, that is, it needs to be divided into batches."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "csharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "var learning_rate = 0.01f;\n",
    "var model = new Trivial();\n",
    "var loss = nn.MSELoss();\n",
    "\n",
    "var data = Enumerable.Range(0,16).Select(_ => rand(32,1000)).ToList<torch.Tensor>();  // Our pretend input data\n",
    "var results = Enumerable.Range(0,16).Select(_ => rand(32,10)).ToList<torch.Tensor>();  // Our pretend ground truth.\n",
    "\n",
    "var optimizer = torch.optim.SGD(model.parameters(), learning_rate);\n",
    "\n",
    "for (int i = 0; i < 300; i++) {\n",
    "\n",
    "    for (int idx = 0; idx < data.Count; idx++) {\n",
    "        // Compute the loss\n",
    "        using var output = loss.forward(model.forward(data[idx]), results[idx]);\n",
    "\n",
    "        // Clear the gradients before doing the back-propagation\n",
    "        model.zero_grad();\n",
    "\n",
    "        // Do back-propagation, which computes all the gradients.\n",
    "        output.backward();\n",
    "\n",
    "        optimizer.step();\n",
    "    }\n",
    "}\n",
    "\n",
    "loss.forward(model.forward(data[0]), results[0]).item<float>()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When I ran this, the loss was down to 0.095 after 1 second. (It took longer the first time around.)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## StepLR\n",
    "\n",
    "StepLR multiplies the learning rate by a constant factor (`gamma`) every `step_size` epochs. The difference it makes to the training loop is that you wrap the optimizer in the scheduler, and then call `step` on the scheduler (once per epoch) as well as on the optimizer (once per batch)."
   ]
  },
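  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of the arithmetic behind a step-decay schedule such as StepLR: the learning rate after a given number of scheduler steps is the initial rate multiplied by `gamma` once for every `step_size` epochs completed, that is, `initial_lr * gamma^floor(epoch / step_size)`. The cell below computes that by hand in plain C#, without calling TorchSharp. `StepDecay` is a hypothetical helper written for this illustration, not a library function, and the values 0.01, 25, and 0.95 are the ones used in the training cell that follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "csharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "using System;\n",
    "\n",
    "// Hypothetical helper (not part of TorchSharp): the learning rate that a\n",
    "// step-decay schedule yields after `epoch` calls to the scheduler's step().\n",
    "double StepDecay(double initialLr, int stepSize, double gamma, int epoch)\n",
    "    => initialLr * Math.Pow(gamma, epoch / stepSize);  // integer division acts as floor\n",
    "\n",
    "// Print the rate at a few epochs, including both sides of the first decay boundary.\n",
    "foreach (var epoch in new[] { 0, 24, 25, 50, 100, 300 })\n",
    "    Console.WriteLine($\"epoch {epoch,3}: lr = {StepDecay(0.01, 25, 0.95, epoch):F6}\");"
   ]
  },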
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "dotnet_interactive": {
     "language": "csharp"
    },
    "vscode": {
     "languageId": "polyglot-notebook"
    }
   },
   "outputs": [],
   "source": [
    "var learning_rate = 0.01f;\n",
    "var model = new Trivial();\n",
    "var loss = nn.MSELoss();\n",
    "\n",
    "var data = Enumerable.Range(0,16).Select(_ => rand(32,1000)).ToList<torch.Tensor>();  // Our pretend input data\n",
    "var results = Enumerable.Range(0,16).Select(_ => rand(32,10)).ToList<torch.Tensor>();  // Our pretend ground truth.\n",
    "\n",
    "var optimizer = torch.optim.SGD(model.parameters(), learning_rate);\n",
    "var scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 25, 0.95);\n",
    "\n",
    "for (int i = 0; i < 300; i++) {\n",
    "\n",
    "    for (int idx = 0; idx < data.Count; idx++) {\n",
    "        // Compute the loss\n",
    "        using var output = loss.forward(model.forward(data[idx]), results[idx]);\n",
    "\n",
    "        // Clear the gradients before doing the back-propagation\n",
    "        model.zero_grad();\n",
    "\n",
    "        // Do back-propagation, which computes all the gradients.\n",
    "        output.backward();\n",
    "\n",
    "        optimizer.step();\n",
    "    }\n",
    "\n",
    "    scheduler.step();\n",
    "}\n",
    "\n",
    "loss.forward(model.forward(data[0]), results[0]).item<float>()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Well, that was underwhelming. The loss (in my case) went up a bit, so that's nothing to get excited about. For this trivial model, using a scheduler isn't going to make a huge difference, and it may not make much of a difference even for complex models. It's very hard to know until you try it, but now you know how to try it out. If you run this trivial example over and over, you will see that the results vary quite a bit from run to run, since both the data and the initial weights are random. The example is simply too simple to show a consistent effect.\n",
    "\n",
    "Regardless, if you pass `verbose: true` when creating the scheduler, you can see from the printed output that the learning rate is adjusted as the epochs proceed.\n",
    "\n",
    "Note: If you're using TorchSharp 0.93.9 and you see odd dips in the learning rate, that's a bug in the verbose printout logic, not in the learning rate scheduler itself."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".NET (C#)",
   "language": "C#",
   "name": ".net-csharp"
  },
  "language_info": {
   "name": "C#"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
